from matplotlib import pyplot as plt
from compare_eva import compare_algorithms
from data_pre import load_data

# Draw evaluation metrics as a grid of bar charts.
def plot_results(results, algorithms=('ALS-WR',),
                 metrics=('RMSE', 'MAE', 'Precision', 'Recall', 'F1')):
    """Plot evaluation metrics as a grid of bar charts and show the figure.

    The grid has one row per metric and one column per entry of *results*.
    Each subplot shows one bar per algorithm.

    Args:
        results: Mapping of dataset name -> per-algorithm results. Values are
            read as ``results[dataset][algorithm][metric]``.
        algorithms: Algorithm names to plot, in bar order. Each must be a key
            of every ``results[dataset]``.
        metrics: Metric names to plot, one subplot row each. Each must be a
            key of every ``results[dataset][algorithm]``.

    Raises:
        KeyError: (for dict inputs) if a requested algorithm or metric is
            missing from *results*.
    """
    algorithms = list(algorithms)
    metrics = list(metrics)
    if not results or not algorithms or not metrics:
        # Nothing to draw: plt.subplots rejects a zero-sized grid and max()
        # rejects an empty sequence.
        return

    # squeeze=False keeps ``axes`` a 2-D array even when there is a single
    # dataset column or a single metric row, so ``axes[i, j]`` always works.
    _, axes = plt.subplots(
        len(metrics), len(results), figsize=(15, 10), squeeze=False
    )
    for j, (dataset, res) in enumerate(results.items()):
        for i, metric in enumerate(metrics):
            ax = axes[i, j]
            values = [res[algo][metric] for algo in algorithms]
            ax.bar(algorithms, values)
            ax.set_title(f'{dataset} - {metric}')
            ax.set_ylabel(metric)
            # Y axis from 0, with one unit of headroom above the tallest bar.
            ax.set_ylim([0, max(values) + 1])

    plt.tight_layout()
    plt.show()


# Run the algorithm comparison on every dataset and plot each result.
# The __main__ guard lets other modules import plot_results without loading
# the data, running every comparison and opening blocking plot windows.
if __name__ == "__main__":
    data = load_data()
    result = {}
    for name, dataset in data.items():
        print(dataset)
        result[name] = compare_algorithms(dataset)
        print(f"draw {name} dataset picture ...")
        # NOTE(review): plot_results expects
        # {dataset: {algorithm: {metric: value}}}. Confirm that
        # compare_algorithms(dataset) returns that shape; if it instead
        # returns {algorithm: {metric: value}}, this call should be
        # plot_results({name: result[name]}).
        plot_results(result[name])